from typing import List

import numpy as np
import torch
from numpy import ndarray
from torch.optim import Adam

from CatModel import CatGPU


def local_score_cat(Data: ndarray, Xi: int, PAi: List[int], param=None):
    """
    Calculate the local score of variable ``Xi`` given its parents ``PAi``.

    A ``CatGPU`` model is built on (parents, target) and trained with Adam for
    ``param['epochs']`` steps. The score reported by the last training step is
    returned. (The original description says the score is based on mutual
    information; that computation lives inside ``CatGPU`` and is not visible
    here.)

    Parameters
    ----------
    Data: (sample, features) array
    Xi: column index of the current (target) variable
    PAi: column indexes of the parent variables; may be empty
    param: settings dict forwarded to ``CatGPU``. It must contain the key
        ``'epochs'`` (number of training steps). Despite the ``None``
        default, it is effectively required.

    Returns
    -------
    score: the ``'score'`` entry returned by the final ``train_step`` call,
        or 0 if ``param['epochs']`` is 0.
    """
    X = torch.tensor(Data[:, Xi])
    if PAi:
        # Select every parent column in one fancy-indexing call. The earlier
        # manual index-accumulation loop did `list += int`, which raised
        # TypeError, and its result was never used anyway.
        PA = torch.tensor(Data[:, PAi])
    else:
        # No parents: pass a single all-zero column as a placeholder parent
        # input so the model still receives a 2-D tensor.
        PA = torch.tensor(np.zeros((X.shape[0], 1)))

    dgp = CatGPU(PA, X, param)
    optim = Adam(dgp.parameters(), lr=0.1)

    score = 0
    for _ in range(param['epochs']):
        # Only the score from the last step is kept.
        score = dgp.train_step(PA, X, optim)['score']
    return score
